#!/usr/bin/env python3

import functools
import itertools
import os
import random
from collections import deque
from copy import deepcopy
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence

import numpy as np
import torch
import wandb
from rpi import logger
from rpi.agents.base import Agent
from rpi.agents.mamba import (ActivePolicySelector, ActiveStateExplorer,
                                MaxValueFn, ValueEnsemble)
from rpi.agents.ppo import update_critic_ensemble
from rpi.evaluation import eval_fn
from rpi.helpers import set_random_seed, to_torch
from rpi.helpers.data import flatten
from rpi.helpers.env import rollout_single_ep, rollout
from rpi.helpers.initializers import ortho_init
from rpi.nn.empirical_normalization import EmpiricalNormalization
from rpi.policies import (GaussianHeadWithStateIndependentCovariance,
                            SoftmaxCategoricalHead)
from rpi.scripts.sweep.default_args import Args
from rpi.value_estimations import (
    _attach_advantage_and_value_target_to_episode,
    _attach_log_prob_to_episodes, _attach_return_and_value_target_to_episode,
    _attach_value_to_episodes)
from torch import nn

from .train import tolist


def _log_value_stats(experts: List[Agent], sampled_states: Sequence, prefix: str, include_mean: bool) -> None:
    """Log value-ensemble statistics of the single expert at each probe state.

    One ``wandb.log`` call is made per state; metrics are keyed under
    ``prefix`` and ``'step'`` is the state's index in ``sampled_states``.

    Args:
        experts: Experts to inspect; exactly one is supported (asserted).
        sampled_states: Probe states fed to ``expert.vfn.forward_stats``.
        prefix: Namespace prepended to every logged metric name.
        include_mean: Whether to also log ``stats.mean`` as ``{prefix}/mean``.
    """
    for step, state in enumerate(sampled_states):
        assert len(experts) == 1
        with torch.no_grad():
            stats = [expert.vfn.forward_stats(to_torch(state).unsqueeze(0), normalize_input=True)
                     for expert in experts]

        head = stats[0]  # only one expert is supported, see the assert above
        record = {
            f'{prefix}/mean_hist': wandb.Histogram(tolist(head.all_means.squeeze())),
            f'{prefix}/stddev': head.std,
            f'{prefix}/input_norm': np.linalg.norm(state),
            f'{prefix}/std_from_mean': torch.std(head.all_means, dim=0).item(),
            'step': step,
        }
        if include_mean:
            record[f'{prefix}/mean'] = head.mean
        wandb.log(record)


def inspect_ensemble(make_env: Callable, experts: List[Agent], max_episode_len: int = 1000):
    """Pretrain the expert's critic ensemble and log its value estimates before and after.

    Steps, in order:
      1. Evaluate every expert with ``eval_fn`` and log the stats to wandb.
      2. Roll out every expert ``Args.pret_num_rollouts`` times and attach
         return / value targets (``gamma=1.0``) to each episode.
      3. Build a set of probe states: 10 states sampled from the expert
         transitions, 5 states from random-policy rollouts, and 10 uniformly
         random vectors with the shape of an expert state.
      4. Log the expert value-ensemble statistics on the probe states.
      5. Update each expert's observation normalizer and critic ensemble on its
         own transitions, logging the critic loss history.
      6. Log the value-ensemble statistics on the same probe states again.

    Args:
        make_env: Zero-argument callable returning an environment; also handed
            to ``eval_fn``.
        experts: Experts to inspect. The value-inspection steps assert that
            exactly one expert is given.
        max_episode_len: Episode length limit passed to ``eval_fn`` and
            ``rollout_single_ep``.
    """
    # `.get` so the function also works when the variable is not exported
    # (e.g. when called from outside the sweep launcher).
    logger.info('cvd', os.environ.get('CUDA_VISIBLE_DEVICES'))

    env = make_env()
    pret_num_rollouts = Args.pret_num_rollouts  # 16 for CartPole, DIP, 512 for HalfCheetah and Ant

    ## Let's evaluate the experts first!
    for expert_idx, expert in enumerate(experts):
        stats = eval_fn(make_env, expert, max_episode_len, num_episodes=100)
        logs = {f'expert-{expert_idx}/{key}': val for key, val in stats.items() if not key.startswith('_')}
        wandb.log({'step': 0, **logs})

    ## For each expert k, collect data D^k by rolling out pi^k
    expert_rollouts = [deque(maxlen=Args.expert_buffer_size) for _ in experts]  # 100 for CartPole, DIP, 2 for HalfCheetah and Ant
    for _ in range(pret_num_rollouts):
        for expert_idx, expert in enumerate(experts):
            episode = rollout_single_ep(env, functools.partial(expert.act, mode=Args.deterministic_experts), max_episode_len)

            # BUG: the way to compute value target should depend on Args.expert_tgtval
            # XL: using true V instead of estimated V
            _attach_return_and_value_target_to_episode(episode, gamma=1.0)
            expert_rollouts[expert_idx].append(episode)

    # Generate states from random policy
    def random_policy(obs):
        return env.action_space.sample()

    randpol_states = []
    for target_step in [3, 5, 10, 12, 14]:
        # Keep rolling out until an episode lasts at least `target_step` steps.
        while True:
            episode = rollout_single_ep(env, random_policy, max_episode_len)
            if len(episode) >= target_step:
                # NOTE(review): this stores the episode's *final* state, so
                # `target_step` only acts as a minimum-length filter. Confirm
                # that the state at `target_step` was not the intended pick.
                randpol_states.append(episode[-1]['state'])
                break

    # Randomly sample from expert transitions
    _expert_transitions = flatten([flatten(expert_rollout) for expert_rollout in expert_rollouts])
    sampled_states = [trans['state'] for trans in np.random.choice(_expert_transitions, size=10)]
    sampled_states += reversed(randpol_states)
    # Uniform-random vectors shaped like real states, as out-of-distribution probes.
    sampled_states += [np.random.random(size=trans['state'].shape) for trans in np.random.choice(_expert_transitions, size=10)]

    # Inspect expert value function before training
    _log_value_stats(experts, sampled_states, prefix='before-pretrain-expert', include_mean=False)

    ## Update value function V^k from D^k  (By a simple Monte Carlo return??)
    for expert_idx, expert in enumerate(experts):
        expert_k_transitions = flatten(expert_rollouts[expert_idx])

        # Update obs_normalizer using the collected expert transitions
        expert.obs_normalizer.experience(to_torch([tr['state'] for tr in expert_k_transitions]))

        # NOTE: num_updates may change the behavior quite a lot.
        _, loss_critic_history = update_critic_ensemble(
            expert,
            expert_k_transitions,
            num_updates=Args.pret_num_updates,  # 100 for CartPole, DIP
            batch_size=Args.batch_size,
            std_from_means=Args.std_from_means,
        )
        for i, loss in enumerate(loss_critic_history):
            wandb.log({
                f'pretrain-expert-{expert_idx}/critic': loss,
                f'pretrain-expert-{expert_idx}/num_transitions': len(expert_k_transitions),
                'pret-step': i,
            })

    # Inspect expert value function after training
    _log_value_stats(experts, sampled_states, prefix='pretrain-expert', include_mean=True)


def main():
    """Build the env, learner policy and single expert, then run `inspect_ensemble`.

    All hyper-parameters are read from `Args`. The expert is loaded from
    ``Args.experts_dir / step_{Args.load_expert_step:06d}.pt``. Runs on CUDA
    when available and falls back to CPU otherwise.
    """
    import gym
    from rpi.agents.mamba import MambaAgent

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gamma = Args.gamma  # 0.995
    lmd = Args.lmd  # 0.97
    load_expert_step = Args.load_expert_step

    set_random_seed(Args.seed)

    def make_env(test=False):
        """Create the DartCartPole env; test envs get a seed derived as ``42 - Args.seed``."""
        from rpi.helpers import env
        seed = Args.seed if not test else 42 - Args.seed
        return env.make_env('DartCartPole-v1', seed=seed)

    # Throwaway env instance, only used to read the observation/action dimensions.
    test_env = make_env()
    state_dim = test_env.observation_space.low.size

    if isinstance(test_env.action_space, gym.spaces.Box):
        # Continuous action space
        act_dim = test_env.action_space.low.size
        policy_head = GaussianHeadWithStateIndependentCovariance(
            action_size=act_dim,
            var_type="diagonal",
            var_func=lambda x: torch.exp(2 * x),  # Parameterize log std
            var_param_init=0,  # log std = 0 => std = 1
        )
    else:
        # Discrete action space (assuming categorical)
        act_dim = test_env.action_space.n
        policy_head = SoftmaxCategoricalHead()

    logger.info('obs_dim', state_dim)
    logger.info('act_dim', act_dim)

    from .train import Factory, get_expert
    pi = Factory.create_pi(state_dim, act_dim, policy_head=policy_head)

    obs_normalizer = EmpiricalNormalization(state_dim, clip_threshold=5)
    # Use the selected device rather than a hard-coded 'cuda', so that the
    # CPU fallback chosen above actually works on machines without a GPU.
    obs_normalizer.to(device)

    # TEMP: Single expert
    expert = get_expert(state_dim, act_dim, deepcopy(policy_head), Path(Args.experts_dir) / f'step_{load_expert_step:06d}.pt',
                        obs_normalizer=obs_normalizer)
    experts = [expert]
    vfn_aggr = MaxValueFn([expert.vfn for expert in experts], obs_normalizers=[expert.obs_normalizer for expert in experts])

    # NOTE: lmd = 1.0 --> Pure RL
    optimizer = torch.optim.Adam(pi.parameters(), lr=1e-3, betas=(0.9, 0.99))

    learner = MambaAgent(pi, vfn_aggr, optimizer, obs_normalizer, gamma=gamma, lambd=lmd)
    learner.to(device)

    inspect_ensemble(make_env, experts)


if __name__ == '__main__':
    # Entry point: pick one configuration out of a sweep file, pin the run to
    # a GPU, start a wandb run and launch `main`.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("sweep_file", help="sweep file")
    # Required: without it `args.line_number` would be None and indexing the
    # sweep below would fail with an opaque TypeError.
    parser.add_argument("-l", "--line-number", type=int, required=True,
                        help="zero-based index of the sweep entry to run")
    args = parser.parse_args()

    # Obtain kwargs from Sweep
    from params_proto.hyper import Sweep
    sweep = Sweep(Args).load(args.sweep_file)
    kwargs = list(sweep)[args.line_number]

    Args._update(kwargs)

    # Spread sweep entries round-robin over a fixed number of GPUs.
    num_gpus = 4
    cvd = args.line_number % num_gpus
    os.environ['CUDA_VISIBLE_DEVICES'] = str(cvd)

    # Group wandb runs by the sweep file's base name (without extension).
    sweep_basename = os.path.splitext(os.path.basename(args.sweep_file))[0]
    wandb.login()
    wandb.init(
        # Set the project where this run will be logged
        project='alops',
        group=sweep_basename,
        config=vars(Args),
    )
    main()
